{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ANN Classifier.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TOOEhkudSJVS"
      },
      "source": [
        "## ANN Classifier"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "69q_-hZESMq6"
      },
      "source": [
        "<b>Importing Required Libraries</b>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "D1zRaRb4QVSj"
      },
      "source": [
        "import numpy as np # linear algebra\n",
        "import pandas as pd # data processing\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.preprocessing import StandardScaler\n"
      ],
      "execution_count": 87,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9VTpNisSYsc"
      },
      "source": [
        "<b>Mounting Google Drive to read data</b>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gzv7KtENRhgt",
        "outputId": "59caf0b5-40e2-478a-e3bb-0eef3dae1ac8"
      },
      "source": [
        "from google.colab import drive\n",
        "\n",
        "# Attach Google Drive to the Colab filesystem at this mount point.\n",
        "GDRIVE_MOUNT_POINT = \"/content/gdrive\"\n",
        "drive.mount(GDRIVE_MOUNT_POINT)"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Mounted at /content/gdrive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WjcRIygUSC1c"
      },
      "source": [
        "# Read the red-wine quality CSV from the mounted Drive into a DataFrame.\n",
        "inFile = \"/content/gdrive/My Drive/Dataset/winequality-red.csv\"\n",
        "data = pd.read_csv(filepath_or_buffer=inFile)"
      ],
      "execution_count": 114,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 203
        },
        "id": "QWuAMtToS61k",
        "outputId": "110657be-e7b9-4601-adae-289a08f89cc0"
      },
      "source": [
        "# Preview the first five rows of the loaded frame.\n",
        "data.head(n=5)"
      ],
      "execution_count": 115,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>fixed acidity</th>\n",
              "      <th>volatile acidity</th>\n",
              "      <th>citric acid</th>\n",
              "      <th>residual sugar</th>\n",
              "      <th>chlorides</th>\n",
              "      <th>free sulfur dioxide</th>\n",
              "      <th>total sulfur dioxide</th>\n",
              "      <th>density</th>\n",
              "      <th>pH</th>\n",
              "      <th>sulphates</th>\n",
              "      <th>alcohol</th>\n",
              "      <th>quality</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>7.4</td>\n",
              "      <td>0.70</td>\n",
              "      <td>0.00</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.076</td>\n",
              "      <td>11.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>0.9978</td>\n",
              "      <td>3.51</td>\n",
              "      <td>0.56</td>\n",
              "      <td>9.4</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>7.8</td>\n",
              "      <td>0.88</td>\n",
              "      <td>0.00</td>\n",
              "      <td>2.6</td>\n",
              "      <td>0.098</td>\n",
              "      <td>25.0</td>\n",
              "      <td>67.0</td>\n",
              "      <td>0.9968</td>\n",
              "      <td>3.20</td>\n",
              "      <td>0.68</td>\n",
              "      <td>9.8</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>7.8</td>\n",
              "      <td>0.76</td>\n",
              "      <td>0.04</td>\n",
              "      <td>2.3</td>\n",
              "      <td>0.092</td>\n",
              "      <td>15.0</td>\n",
              "      <td>54.0</td>\n",
              "      <td>0.9970</td>\n",
              "      <td>3.26</td>\n",
              "      <td>0.65</td>\n",
              "      <td>9.8</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>11.2</td>\n",
              "      <td>0.28</td>\n",
              "      <td>0.56</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.075</td>\n",
              "      <td>17.0</td>\n",
              "      <td>60.0</td>\n",
              "      <td>0.9980</td>\n",
              "      <td>3.16</td>\n",
              "      <td>0.58</td>\n",
              "      <td>9.8</td>\n",
              "      <td>6</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>7.4</td>\n",
              "      <td>0.70</td>\n",
              "      <td>0.00</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.076</td>\n",
              "      <td>11.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>0.9978</td>\n",
              "      <td>3.51</td>\n",
              "      <td>0.56</td>\n",
              "      <td>9.4</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   fixed acidity  volatile acidity  citric acid  ...  sulphates  alcohol  quality\n",
              "0            7.4              0.70         0.00  ...       0.56      9.4        5\n",
              "1            7.8              0.88         0.00  ...       0.68      9.8        5\n",
              "2            7.8              0.76         0.04  ...       0.65      9.8        5\n",
              "3           11.2              0.28         0.56  ...       0.58      9.8        6\n",
              "4            7.4              0.70         0.00  ...       0.56      9.4        5\n",
              "\n",
              "[5 rows x 12 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 115
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pEGXf0iBS9Y0"
      },
      "source": [
        "# Target is the \"quality\" column; every other column is a predictor.\n",
        "y = data[\"quality\"]\n",
        "X = data.drop(columns=[\"quality\"])"
      ],
      "execution_count": 116,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IvIM5bVnXB_v"
      },
      "source": [
        "<b>Feature Selector</b>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pe4010o5f5vB"
      },
      "source": [
        "# Binary mask over the feature columns (every column of `data` except\n",
        "# \"quality\"), in column order: 1 keeps the feature, 0 drops it.\n",
        "selected_features_index = [1,1,1,1,1,1,1,1,1,1,1]\n",
        "\n",
        "# Derive the candidate columns from `data` instead of the current `X`, so that\n",
        "# re-running this cell with a different mask never operates on an\n",
        "# already-reduced frame (which previously misaligned the mask or raised IndexError).\n",
        "cols = data.columns.drop(\"quality\")\n",
        "if len(selected_features_index) != len(cols):\n",
        "  raise ValueError(\n",
        "      f'selected_features_index has {len(selected_features_index)} entries '\n",
        "      f'but there are {len(cols)} feature columns'\n",
        "  )\n",
        "\n",
        "selected_features = [\n",
        "    name for name, keep in zip(cols, selected_features_index) if keep == 1\n",
        "]\n",
        "X = data[selected_features]\n",
        "\n",
        "# Standardize each selected feature with scikit-learn's StandardScaler.\n",
        "scaler = StandardScaler()\n",
        "X_scaled = scaler.fit_transform(X)"
      ],
      "execution_count": 117,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 203
        },
        "id": "25Mv770Ala2F",
        "outputId": "777bb46f-7cff-43ed-b9c9-601ebb9384e5"
      },
      "source": [
        "# Preview the first five rows of the selected feature frame.\n",
        "X.head(n=5)"
      ],
      "execution_count": 118,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>fixed acidity</th>\n",
              "      <th>volatile acidity</th>\n",
              "      <th>citric acid</th>\n",
              "      <th>residual sugar</th>\n",
              "      <th>chlorides</th>\n",
              "      <th>free sulfur dioxide</th>\n",
              "      <th>total sulfur dioxide</th>\n",
              "      <th>density</th>\n",
              "      <th>pH</th>\n",
              "      <th>sulphates</th>\n",
              "      <th>alcohol</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>7.4</td>\n",
              "      <td>0.70</td>\n",
              "      <td>0.00</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.076</td>\n",
              "      <td>11.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>0.9978</td>\n",
              "      <td>3.51</td>\n",
              "      <td>0.56</td>\n",
              "      <td>9.4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>7.8</td>\n",
              "      <td>0.88</td>\n",
              "      <td>0.00</td>\n",
              "      <td>2.6</td>\n",
              "      <td>0.098</td>\n",
              "      <td>25.0</td>\n",
              "      <td>67.0</td>\n",
              "      <td>0.9968</td>\n",
              "      <td>3.20</td>\n",
              "      <td>0.68</td>\n",
              "      <td>9.8</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>7.8</td>\n",
              "      <td>0.76</td>\n",
              "      <td>0.04</td>\n",
              "      <td>2.3</td>\n",
              "      <td>0.092</td>\n",
              "      <td>15.0</td>\n",
              "      <td>54.0</td>\n",
              "      <td>0.9970</td>\n",
              "      <td>3.26</td>\n",
              "      <td>0.65</td>\n",
              "      <td>9.8</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>11.2</td>\n",
              "      <td>0.28</td>\n",
              "      <td>0.56</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.075</td>\n",
              "      <td>17.0</td>\n",
              "      <td>60.0</td>\n",
              "      <td>0.9980</td>\n",
              "      <td>3.16</td>\n",
              "      <td>0.58</td>\n",
              "      <td>9.8</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>7.4</td>\n",
              "      <td>0.70</td>\n",
              "      <td>0.00</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.076</td>\n",
              "      <td>11.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>0.9978</td>\n",
              "      <td>3.51</td>\n",
              "      <td>0.56</td>\n",
              "      <td>9.4</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   fixed acidity  volatile acidity  citric acid  ...    pH  sulphates  alcohol\n",
              "0            7.4              0.70         0.00  ...  3.51       0.56      9.4\n",
              "1            7.8              0.88         0.00  ...  3.20       0.68      9.8\n",
              "2            7.8              0.76         0.04  ...  3.26       0.65      9.8\n",
              "3           11.2              0.28         0.56  ...  3.16       0.58      9.8\n",
              "4            7.4              0.70         0.00  ...  3.51       0.56      9.4\n",
              "\n",
              "[5 rows x 11 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 118
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lqMTSe9alD5Z",
        "outputId": "12f1127f-300d-4823-91ed-b0a4676f977f"
      },
      "source": [
        "# Preview the first five target values.\n",
        "y.head(n=5)"
      ],
      "execution_count": 119,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0    5\n",
              "1    5\n",
              "2    5\n",
              "3    6\n",
              "4    5\n",
              "Name: quality, dtype: int64"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 119
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "i0OvanLlVieb"
      },
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# Split the unscaled features first, then fit the scaler on the training rows\n",
        "# only. Fitting it on the full dataset before splitting leaks test-set\n",
        "# statistics (mean/variance) into the data the model is trained on.\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n",
        "scaler = StandardScaler().fit(X_train)\n",
        "X_train = scaler.transform(X_train)\n",
        "X_test = scaler.transform(X_test)"
      ],
      "execution_count": 120,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 203
        },
        "id": "8x4FkR8ZWf2Y",
        "outputId": "db252aae-4b5c-4ebe-9498-e83b2345eff8"
      },
      "source": [
        "# Show the first five rows of the originally loaded frame again.\n",
        "data.head(n=5)"
      ],
      "execution_count": 121,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>fixed acidity</th>\n",
              "      <th>volatile acidity</th>\n",
              "      <th>citric acid</th>\n",
              "      <th>residual sugar</th>\n",
              "      <th>chlorides</th>\n",
              "      <th>free sulfur dioxide</th>\n",
              "      <th>total sulfur dioxide</th>\n",
              "      <th>density</th>\n",
              "      <th>pH</th>\n",
              "      <th>sulphates</th>\n",
              "      <th>alcohol</th>\n",
              "      <th>quality</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>7.4</td>\n",
              "      <td>0.70</td>\n",
              "      <td>0.00</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.076</td>\n",
              "      <td>11.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>0.9978</td>\n",
              "      <td>3.51</td>\n",
              "      <td>0.56</td>\n",
              "      <td>9.4</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>7.8</td>\n",
              "      <td>0.88</td>\n",
              "      <td>0.00</td>\n",
              "      <td>2.6</td>\n",
              "      <td>0.098</td>\n",
              "      <td>25.0</td>\n",
              "      <td>67.0</td>\n",
              "      <td>0.9968</td>\n",
              "      <td>3.20</td>\n",
              "      <td>0.68</td>\n",
              "      <td>9.8</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>7.8</td>\n",
              "      <td>0.76</td>\n",
              "      <td>0.04</td>\n",
              "      <td>2.3</td>\n",
              "      <td>0.092</td>\n",
              "      <td>15.0</td>\n",
              "      <td>54.0</td>\n",
              "      <td>0.9970</td>\n",
              "      <td>3.26</td>\n",
              "      <td>0.65</td>\n",
              "      <td>9.8</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>11.2</td>\n",
              "      <td>0.28</td>\n",
              "      <td>0.56</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.075</td>\n",
              "      <td>17.0</td>\n",
              "      <td>60.0</td>\n",
              "      <td>0.9980</td>\n",
              "      <td>3.16</td>\n",
              "      <td>0.58</td>\n",
              "      <td>9.8</td>\n",
              "      <td>6</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>7.4</td>\n",
              "      <td>0.70</td>\n",
              "      <td>0.00</td>\n",
              "      <td>1.9</td>\n",
              "      <td>0.076</td>\n",
              "      <td>11.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>0.9978</td>\n",
              "      <td>3.51</td>\n",
              "      <td>0.56</td>\n",
              "      <td>9.4</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   fixed acidity  volatile acidity  citric acid  ...  sulphates  alcohol  quality\n",
              "0            7.4              0.70         0.00  ...       0.56      9.4        5\n",
              "1            7.8              0.88         0.00  ...       0.68      9.8        5\n",
              "2            7.8              0.76         0.04  ...       0.65      9.8        5\n",
              "3           11.2              0.28         0.56  ...       0.58      9.8        6\n",
              "4            7.4              0.70         0.00  ...       0.56      9.4        5\n",
              "\n",
              "[5 rows x 12 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 121
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UlASXKVflhLa"
      },
      "source": [
        "tf.random.set_seed(42)\n",
        "\n",
        "# Three 100-unit ReLU layers, then a 10-unit layer and a single-unit output,\n",
        "# both without an explicit activation. Layers are created in the same order as before.\n",
        "dense_stack = [tf.keras.layers.Dense(100, activation='relu') for _ in range(3)]\n",
        "dense_stack.append(tf.keras.layers.Dense(10))\n",
        "dense_stack.append(tf.keras.layers.Dense(1))\n",
        "model = tf.keras.Sequential(dense_stack)\n",
        "\n",
        "# Mean absolute percentage error is used both as the loss and as the reported metric.\n",
        "model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),\n",
        "    loss=tf.keras.losses.mape,\n",
        "    metrics=['mape'],\n",
        ")\n",
        "\n",
        "# verbose=0 keeps the 355 epochs of training output out of the notebook.\n",
        "history = model.fit(X_train, y_train, epochs=355, verbose=0)"
      ],
      "execution_count": 122,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aQQg7fmIlukB",
        "outputId": "68785549-c704-40cb-eb7b-b0f4a476ecde"
      },
      "source": [
        "# Evaluate on the held-out split; returns the loss followed by the compiled 'mape' metric.\n",
        "loss, mape = model.evaluate(x=X_test, y=y_test)"
      ],
      "execution_count": 123,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "15/15 [==============================] - 0s 1ms/step - loss: 9.0868 - mape: 9.0868\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "m9a9zU8Gqgi5",
        "outputId": "536fb38c-d31b-4afa-e664-94d5cefb7e62"
      },
      "source": [
        "# 100 - MAPE is a regression closeness score, not a classification accuracy,\n",
        "# so it is reported under its real name instead of being labelled \"accuracy\".\n",
        "print(f'loss: {loss}, 100 - MAPE: {100-mape}')"
      ],
      "execution_count": 124,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "loss: 9.08675765991211, accuracy: 90.91324234008789\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}